// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorReductionSycl.h
 *
 * \brief:
 *  This is the specialization of the reduction operation for SYCL devices.
 *  A two-phase reduction approach is used, because the GPU offers no global
 *  synchronization of global memory across different work-groups/thread-blocks.
 *  Two kernels are therefore needed to reduce the data: the first kernel reduces
 *  the data locally, and each work-group/thread-block writes its partial result
 *  to global memory. In the second phase (the global reduction), a single
 *  work-group/thread-block reduces the intermediate results into one element.
 *  Here is an NVIDIA presentation explaining the optimized two-phase reduction
 *  algorithm on the GPU:
 *  https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
 *
 *****************************************************************/
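
/* A minimal sketch of the two-phase pattern described above (illustration only;
 * the identifiers below are hypothetical and do not appear in this file):
 *
 *   // Phase 1: one kernel launch; each work-group tree-reduces its chunk of the
 *   // input in local memory and writes one partial result to global memory.
 *   //   partial[group_id] = local_tree_reduce(input_chunk(group_id));
 *
 *   // Phase 2: a second kernel launch with a single work-group combines the
 *   // partial results into the final value.
 *   //   output[0] = local_tree_reduce(partial[0 .. num_groups - 1]);
 */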

#ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_REDUCTION_SYCL_HPP
#define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_REDUCTION_SYCL_HPP
namespace Eigen {
namespace TensorSycl {
    namespace internal {

        template <typename Op, typename CoeffReturnType, typename Index, bool Vectorizable> struct OpDefiner
        {
            typedef typename Vectorise<CoeffReturnType, Eigen::SyclDevice, Vectorizable>::PacketReturnType PacketReturnType;
            typedef Op type;
            static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE type get_op(Op& op) { return op; }

            static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType finalise_op(const PacketReturnType& accumulator, const Index&) { return accumulator; }
        };

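        // Mean reducer specializations: the accumulation is delegated to a SumReducer and
        // the division by the number of reduced coefficients is applied afterwards in
        // finalise_op (scalar division below, packet-wise pdiv in the vectorized variant).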
        template <typename CoeffReturnType, typename Index> struct OpDefiner<Eigen::internal::MeanReducer<CoeffReturnType>, CoeffReturnType, Index, false>
        {
            typedef Eigen::internal::SumReducer<CoeffReturnType> type;
            static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE type get_op(Eigen::internal::MeanReducer<CoeffReturnType>&) { return type(); }

            static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType finalise_op(const CoeffReturnType& accumulator, const Index& scale)
            {
                ::Eigen::internal::scalar_quotient_op<CoeffReturnType> quotient_op;
                return quotient_op(accumulator, CoeffReturnType(scale));
            }
        };

        template <typename CoeffReturnType, typename Index> struct OpDefiner<Eigen::internal::MeanReducer<CoeffReturnType>, CoeffReturnType, Index, true>
        {
            typedef typename Vectorise<CoeffReturnType, Eigen::SyclDevice, true>::PacketReturnType PacketReturnType;
            typedef Eigen::internal::SumReducer<CoeffReturnType> type;
            static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE type get_op(Eigen::internal::MeanReducer<CoeffReturnType>&) { return type(); }

            static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType finalise_op(const PacketReturnType& accumulator, const Index& scale)
            {
                return ::Eigen::internal::pdiv(accumulator, ::Eigen::internal::pset1<PacketReturnType>(CoeffReturnType(scale)));
            }
        };

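        // Second kernel of the full reduction: a single work-group in which every work-item
        // loads one partial result produced by the first kernel; the work-group then
        // tree-reduces those values in local memory and work-item 0 writes the final result.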
        template <typename CoeffReturnType, typename OpType, typename InputAccessor, typename OutputAccessor, typename Index, Index local_range>
        struct SecondStepFullReducer
        {
            typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LocalAccessor;
            typedef OpDefiner<OpType, CoeffReturnType, Index, true> OpDef;
            typedef typename OpDef::type Op;
            LocalAccessor scratch;
            InputAccessor aI;
            OutputAccessor outAcc;
            Op op;
            SecondStepFullReducer(LocalAccessor scratch_, InputAccessor aI_, OutputAccessor outAcc_, OpType op_)
                : scratch(scratch_), aI(aI_), outAcc(outAcc_), op(OpDef::get_op(op_))
            {
            }

            void operator()(cl::sycl::nd_item<1> itemID)
            {
                // Our empirical results show that the best performance is achieved when
                // there is only one element per work-item to reduce in the second step;
                // the second-step reduction time is then almost negligible. Hence, in the
                // second step the input size is fixed to the local size, so each work-item
                // reads exactly one element. The algorithm must be changed if the number of
                // elements reduced per work-item in the second step is greater than 1;
                // otherwise the result will be wrong.
                const Index localid = itemID.get_local_id(0);
                auto aInPtr = aI.get_pointer() + localid;
                auto aOutPtr = outAcc.get_pointer();
                CoeffReturnType* scratchptr = scratch.get_pointer();
                CoeffReturnType accumulator = *aInPtr;

                scratchptr[localid] = op.finalize(accumulator);
                for (Index offset = itemID.get_local_range(0) / 2; offset > 0; offset /= 2)
                {
                    itemID.barrier(cl::sycl::access::fence_space::local_space);
                    if (localid < offset)
                    {
                        op.reduce(scratchptr[localid + offset], &accumulator);
                        scratchptr[localid] = op.finalize(accumulator);
                    }
                }
                if (localid == 0)
                    *aOutPtr = op.finalize(accumulator);
            }
        };

        // Full reduction, first phase. In this version vectorization is enabled and the
        // reduction accepts any generic reducer op (e.g. max, min, sum, mean, iamax, iamin, etc.).
        template <typename Evaluator, typename OpType, typename Evaluator::Index local_range> class FullReductionKernelFunctor
        {
        public:
            typedef typename Evaluator::CoeffReturnType CoeffReturnType;
            typedef typename Evaluator::Index Index;
            typedef OpDefiner<OpType, typename Evaluator::CoeffReturnType, Index, (Evaluator::ReducerTraits::PacketAccess & Evaluator::InputPacketAccess)>
                OpDef;

            typedef typename OpDef::type Op;
            typedef typename Evaluator::EvaluatorPointerType EvaluatorPointerType;
            typedef typename Evaluator::PacketReturnType PacketReturnType;
            typedef typename ::Eigen::internal::
                conditional<(Evaluator::ReducerTraits::PacketAccess & Evaluator::InputPacketAccess), PacketReturnType, CoeffReturnType>::type OutType;
            typedef cl::sycl::accessor<OutType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LocalAccessor;
            LocalAccessor scratch;
            Evaluator evaluator;
            EvaluatorPointerType final_output;
            Index rng;
            Op op;

            FullReductionKernelFunctor(LocalAccessor scratch_, Evaluator evaluator_, EvaluatorPointerType final_output_, Index rng_, OpType op_)
                : scratch(scratch_), evaluator(evaluator_), final_output(final_output_), rng(rng_), op(OpDef::get_op(op_))
            {
            }

            void operator()(cl::sycl::nd_item<1> itemID) { compute_reduction(itemID); }

            template <bool Vect = (Evaluator::ReducerTraits::PacketAccess & Evaluator::InputPacketAccess)>
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<Vect>::type compute_reduction(const cl::sycl::nd_item<1>& itemID)
            {
                auto output_ptr = final_output.get_pointer();
                Index VectorizedRange = (rng / Evaluator::PacketSize) * Evaluator::PacketSize;
                Index globalid = itemID.get_global_id(0);
                Index localid = itemID.get_local_id(0);
                Index step = Evaluator::PacketSize * itemID.get_global_range(0);
                Index start = Evaluator::PacketSize * globalid;
                // Vectorized part: grid-stride loop over whole packets.
                PacketReturnType packetAccumulator = op.template initializePacket<PacketReturnType>();
                for (Index i = start; i < VectorizedRange; i += step)
                { op.template reducePacket<PacketReturnType>(evaluator.impl().template packet<Unaligned>(i), &packetAccumulator); }
                globalid += VectorizedRange;
                // Scalar tail: remaining elements that do not fill a whole packet.
                for (Index i = globalid; i < rng; i += itemID.get_global_range(0))
                {
                    op.template reducePacket<PacketReturnType>(
                        ::Eigen::TensorSycl::internal::PacketWrapper<PacketReturnType, Evaluator::PacketSize>::convert_to_packet_type(evaluator.impl().coeff(i),
                                                                                                                                      op.initialize()),
                        &packetAccumulator);
                }
                scratch[localid] = packetAccumulator = OpDef::finalise_op(op.template finalizePacket<PacketReturnType>(packetAccumulator), rng);
                // Tree reduction in local memory; the local size is always a power of 2.
                EIGEN_UNROLL_LOOP
                for (Index offset = local_range / 2; offset > 0; offset /= 2)
                {
                    itemID.barrier(cl::sycl::access::fence_space::local_space);
                    if (localid < offset)
                    {
                        op.template reducePacket<PacketReturnType>(scratch[localid + offset], &packetAccumulator);
                        scratch[localid] = op.template finalizePacket<PacketReturnType>(packetAccumulator);
                    }
                }
                if (localid == 0)
                {
                    output_ptr[itemID.get_group(0)] = op.finalizeBoth(op.initialize(), op.template finalizePacket<PacketReturnType>(packetAccumulator));
                }
            }

            template <bool Vect = (Evaluator::ReducerTraits::PacketAccess & Evaluator::InputPacketAccess)>
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename ::Eigen::internal::enable_if<!Vect>::type compute_reduction(const cl::sycl::nd_item<1>& itemID)
            {
                auto output_ptr = final_output.get_pointer();
                Index globalid = itemID.get_global_id(0);
                Index localid = itemID.get_local_id(0);
                CoeffReturnType accumulator = op.initialize();
                // Grid-stride loop over all elements (scalar path, no vectorization).
                for (Index i = globalid; i < rng; i += itemID.get_global_range(0)) { op.reduce(evaluator.impl().coeff(i), &accumulator); }
                scratch[localid] = accumulator = OpDef::finalise_op(op.finalize(accumulator), rng);

                // Tree reduction in local memory; the local size is always a power of 2.
                EIGEN_UNROLL_LOOP
                for (Index offset = local_range / 2; offset > 0; offset /= 2)
                {
                    itemID.barrier(cl::sycl::access::fence_space::local_space);
                    if (localid < offset)
                    {
                        op.reduce(scratch[localid + offset], &accumulator);
                        scratch[localid] = op.finalize(accumulator);
                    }
                }
                if (localid == 0)
                {
                    output_ptr[itemID.get_group(0)] = op.finalize(accumulator);
                }
            }
        };

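        // Generic partial reduction used as a fallback (e.g. by ArgMax): each work-item
        // sequentially reduces all coefficients belonging to one preserved output index.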
        template <typename Evaluator, typename OpType> class GenericNondeterministicReducer
        {
        public:
            typedef typename Evaluator::CoeffReturnType CoeffReturnType;
            typedef typename Evaluator::EvaluatorPointerType EvaluatorPointerType;
            typedef typename Evaluator::Index Index;
            typedef OpDefiner<OpType, CoeffReturnType, Index, false> OpDef;
            typedef typename OpDef::type Op;
            template <typename Scratch>
            GenericNondeterministicReducer(Scratch,
                                           Evaluator evaluator_,
                                           EvaluatorPointerType output_accessor_,
                                           OpType functor_,
                                           Index range_,
                                           Index num_values_to_reduce_)
                : evaluator(evaluator_), output_accessor(output_accessor_), functor(OpDef::get_op(functor_)), range(range_),
                  num_values_to_reduce(num_values_to_reduce_)
            {
            }

            void operator()(cl::sycl::nd_item<1> itemID)
            {
                auto output_accessor_ptr = output_accessor.get_pointer();
                /// const cast added as a naive solution to solve the qualifier drop error
                Index globalid = static_cast<Index>(itemID.get_global_linear_id());
                if (globalid < range)
                {
                    CoeffReturnType accum = functor.initialize();
                    Eigen::internal::GenericDimReducer<Evaluator::NumReducedDims - 1, Evaluator, Op>::reduce(
                        evaluator, evaluator.firstInput(globalid), functor, &accum);
                    output_accessor_ptr[globalid] = OpDef::finalise_op(functor.finalize(accum), num_values_to_reduce);
                }
            }

        private:
            Evaluator evaluator;
            EvaluatorPointerType output_accessor;
            Op functor;
            Index range;
            Index num_values_to_reduce;
        };

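        // Selects which dimension is reduced: the inner-most (contiguous) one or the
        // outer-most one. The partial reduction kernel indexes the input differently
        // for each case.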
        enum class reduction_dim
        {
            inner_most,
            outer_most
        };
        // Partial reduction kernel. The work-items of a work-group form a 2D tile of
        // PannelParameters::LocalThreadSizeP x PannelParameters::LocalThreadSizeR threads,
        // where P indexes the preserved dimension and R the reduced dimension.
        template <typename Evaluator, typename OpType, typename PannelParameters, reduction_dim rt> struct PartialReductionKernel
        {
            typedef typename Evaluator::CoeffReturnType CoeffReturnType;
            typedef typename Evaluator::EvaluatorPointerType EvaluatorPointerType;
            typedef typename Evaluator::Index Index;
            typedef OpDefiner<OpType, CoeffReturnType, Index, false> OpDef;
            typedef typename OpDef::type Op;
            typedef cl::sycl::accessor<CoeffReturnType, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> ScratchAcc;
            ScratchAcc scratch;
            Evaluator evaluator;
            EvaluatorPointerType output_accessor;
            Op op;
            const Index preserve_elements_num_groups;
            const Index reduce_elements_num_groups;
            const Index num_coeffs_to_preserve;
            const Index num_coeffs_to_reduce;

            PartialReductionKernel(ScratchAcc scratch_,
                                   Evaluator evaluator_,
                                   EvaluatorPointerType output_accessor_,
                                   OpType op_,
                                   const Index preserve_elements_num_groups_,
                                   const Index reduce_elements_num_groups_,
                                   const Index num_coeffs_to_preserve_,
                                   const Index num_coeffs_to_reduce_)
                : scratch(scratch_), evaluator(evaluator_), output_accessor(output_accessor_), op(OpDef::get_op(op_)),
                  preserve_elements_num_groups(preserve_elements_num_groups_), reduce_elements_num_groups(reduce_elements_num_groups_),
                  num_coeffs_to_preserve(num_coeffs_to_preserve_), num_coeffs_to_reduce(num_coeffs_to_reduce_)
            {
            }

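            // Accumulates, for one preserved coefficient (globalPId), the input values along
            // the reduce dimension, striding by LocalThreadSizeR * reduce_elements_num_groups.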
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void element_wise_reduce(Index globalRId, Index globalPId, CoeffReturnType& accumulator)
            {
                if (globalPId >= num_coeffs_to_preserve)
                {
                    return;
                }
                Index global_offset =
                    rt == reduction_dim::outer_most ? globalPId + (globalRId * num_coeffs_to_preserve) : globalRId + (globalPId * num_coeffs_to_reduce);
                Index localOffset = globalRId;

                const Index per_thread_local_stride = PannelParameters::LocalThreadSizeR * reduce_elements_num_groups;
                const Index per_thread_global_stride =
                    rt == reduction_dim::outer_most ? num_coeffs_to_preserve * per_thread_local_stride : per_thread_local_stride;
                for (Index i = globalRId; i < num_coeffs_to_reduce; i += per_thread_local_stride)
                {
                    op.reduce(evaluator.impl().coeff(global_offset), &accumulator);
                    localOffset += per_thread_local_stride;
                    global_offset += per_thread_global_stride;
                }
            }
            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID)
            {
                const Index linearLocalThreadId = itemID.get_local_id(0);
                Index pLocalThreadId = rt == reduction_dim::outer_most ? linearLocalThreadId % PannelParameters::LocalThreadSizeP :
                                                                         linearLocalThreadId / PannelParameters::LocalThreadSizeR;
                Index rLocalThreadId = rt == reduction_dim::outer_most ? linearLocalThreadId / PannelParameters::LocalThreadSizeP :
                                                                         linearLocalThreadId % PannelParameters::LocalThreadSizeR;
                const Index pGroupId =
                    rt == reduction_dim::outer_most ? itemID.get_group(0) % preserve_elements_num_groups : itemID.get_group(0) / reduce_elements_num_groups;
                const Index rGroupId =
                    rt == reduction_dim::outer_most ? itemID.get_group(0) / preserve_elements_num_groups : itemID.get_group(0) % reduce_elements_num_groups;

                Index globalPId = pGroupId * PannelParameters::LocalThreadSizeP + pLocalThreadId;
                const Index globalRId = rGroupId * PannelParameters::LocalThreadSizeR + rLocalThreadId;
                auto scratchPtr = scratch.get_pointer().get();
                auto outPtr = output_accessor.get_pointer() + (reduce_elements_num_groups > 1 ? rGroupId * num_coeffs_to_preserve : 0);
                CoeffReturnType accumulator = op.initialize();

                element_wise_reduce(globalRId, globalPId, accumulator);

                accumulator = OpDef::finalise_op(op.finalize(accumulator), num_coeffs_to_reduce);
                scratchPtr[pLocalThreadId + rLocalThreadId * (PannelParameters::LocalThreadSizeP + PannelParameters::BC)] = accumulator;
                if (rt == reduction_dim::inner_most)
                {
                    pLocalThreadId = linearLocalThreadId % PannelParameters::LocalThreadSizeP;
                    rLocalThreadId = linearLocalThreadId / PannelParameters::LocalThreadSizeP;
                    globalPId = pGroupId * PannelParameters::LocalThreadSizeP + pLocalThreadId;
                }

                /* Apply the reduction operation between the current local
                 * id and the one on the other half of the vector. */
                auto out_scratch_ptr = scratchPtr + (pLocalThreadId + (rLocalThreadId * (PannelParameters::LocalThreadSizeP + PannelParameters::BC)));
                itemID.barrier(cl::sycl::access::fence_space::local_space);
                if (rt == reduction_dim::inner_most)
                {
                    accumulator = *out_scratch_ptr;
                }
                // LocalThreadSizeR is always a power of 2.
                EIGEN_UNROLL_LOOP
                for (Index offset = PannelParameters::LocalThreadSizeR >> 1; offset > 0; offset >>= 1)
                {
                    if (rLocalThreadId < offset)
                    {
                        op.reduce(out_scratch_ptr[(PannelParameters::LocalThreadSizeP + PannelParameters::BC) * offset], &accumulator);
                        // The result has already been divided for the mean reducer in the
                        // previous reduction, so there is no need to divide it again.
                        *out_scratch_ptr = op.finalize(accumulator);
                    }
                    /* All threads collectively read from global memory into local.
                     * The barrier ensures all threads' IO is resolved before
                     * execution continues (strictly speaking, all threads within
                     * a single work-group - there is no co-ordination between
                     * work-groups, only work-items). */
                    itemID.barrier(cl::sycl::access::fence_space::local_space);
                }

                if (rLocalThreadId == 0 && (globalPId < num_coeffs_to_preserve))
                {
                    outPtr[globalPId] = op.finalize(accumulator);
                }
            }
        };

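        // Second kernel of the two-step partial reduction: combines the rNumGroups partial
        // panels produced by PartialReductionKernel. Each work-item handles one preserved
        // coefficient and walks through its partial results with a stride of
        // num_coeffs_to_preserve.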
        template <typename OutScalar, typename Index, typename InputAccessor, typename OutputAccessor, typename OpType> struct SecondStepPartialReduction
        {
            typedef OpDefiner<OpType, OutScalar, Index, false> OpDef;
            typedef typename OpDef::type Op;
            typedef cl::sycl::accessor<OutScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> ScratchAccessor;
            InputAccessor input_accessor;
            OutputAccessor output_accessor;
            Op op;
            const Index num_coeffs_to_preserve;
            const Index num_coeffs_to_reduce;

            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SecondStepPartialReduction(ScratchAccessor,
                                                                             InputAccessor input_accessor_,
                                                                             OutputAccessor output_accessor_,
                                                                             OpType op_,
                                                                             const Index num_coeffs_to_preserve_,
                                                                             const Index num_coeffs_to_reduce_)
                : input_accessor(input_accessor_), output_accessor(output_accessor_), op(OpDef::get_op(op_)), num_coeffs_to_preserve(num_coeffs_to_preserve_),
                  num_coeffs_to_reduce(num_coeffs_to_reduce_)
            {
            }

            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(cl::sycl::nd_item<1> itemID)
            {
                const Index globalId = itemID.get_global_id(0);

                if (globalId >= num_coeffs_to_preserve)
                    return;

                auto in_ptr = input_accessor.get_pointer() + globalId;

                OutScalar accumulator = op.initialize();
                // num_coeffs_to_reduce is not bigger than 256.
                for (Index i = 0; i < num_coeffs_to_reduce; i++)
                {
                    op.reduce(*in_ptr, &accumulator);
                    in_ptr += num_coeffs_to_preserve;
                }
                output_accessor.get_pointer()[globalId] = op.finalize(accumulator);
            }
        };

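        // Compile-time panel configuration for the partial reduction: local thread sizes
        // along the preserved (P) and reduced (R) dimensions, plus the BC flag that pads the
        // scratch row length (LocalThreadSizeP + BC), presumably to avoid local-memory bank
        // conflicts.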
        template <typename Index, Index LTP, Index LTR, bool BC_> struct ReductionPannel
        {
            static EIGEN_CONSTEXPR Index LocalThreadSizeP = LTP;
            static EIGEN_CONSTEXPR Index LocalThreadSizeR = LTR;
            static EIGEN_CONSTEXPR bool BC = BC_;
        };

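        // Computes the launch geometry for a partial reduction and dispatches either a
        // single-kernel reduction (rNumGroups == 1) or the two-step version, where the first
        // kernel writes rNumGroups partial panels into a temporary buffer and
        // SecondStepPartialReduction combines them into the final output.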
        template <typename Self, typename Op, TensorSycl::internal::reduction_dim rt> struct PartialReducerLauncher
        {
            typedef typename Self::EvaluatorPointerType EvaluatorPointerType;
            typedef typename Self::CoeffReturnType CoeffReturnType;
            typedef typename Self::Storage Storage;
            typedef typename Self::Index Index;
            typedef ReductionPannel<typename Self::Index, EIGEN_SYCL_LOCAL_THREAD_DIM0, EIGEN_SYCL_LOCAL_THREAD_DIM1, true> PannelParameters;

            typedef PartialReductionKernel<Self, Op, PannelParameters, rt> SyclReducerKerneType;

            static bool run(const Self& self,
                            const Op& reducer,
                            const Eigen::SyclDevice& dev,
                            EvaluatorPointerType output,
                            Index num_coeffs_to_reduce,
                            Index num_coeffs_to_preserve)
            {
                Index roundUpP = roundUp(num_coeffs_to_preserve, PannelParameters::LocalThreadSizeP);

                // getPowerOfTwo makes sure the local range is a power of 2 and
                // <= maxSyclThreadPerBlock; this lets us avoid an extra check in the
                // kernel.
                static_assert(!((PannelParameters::LocalThreadSizeP * PannelParameters::LocalThreadSizeR) &
                                (PannelParameters::LocalThreadSizeP * PannelParameters::LocalThreadSizeR - 1)),
                              "The Local thread size must be a power of 2 for the reduction "
                              "operation");

                EIGEN_CONSTEXPR Index localRange = PannelParameters::LocalThreadSizeP * PannelParameters::LocalThreadSizeR;
                // Here we force the reduction to be at most a 2-step reduction.
                // Our empirical results show that if each thread reduces at least 64
                // elements individually, we get better performance; however, this can vary
                // across platforms. They also show that for the inner_most dim reducer it is
                // better to have 8 groups along the reduce dimension for sizes > 1024 to
                // achieve the best performance.
                const Index reductionPerThread = 64;
                Index cu = dev.getPowerOfTwo(dev.getNumSyclMultiProcessors(), true);
                const Index pNumGroups = roundUpP / PannelParameters::LocalThreadSizeP;
                Index rGroups = (cu + pNumGroups - 1) / pNumGroups;
                const Index rNumGroups = num_coeffs_to_reduce > reductionPerThread * localRange ? std::min(rGroups, localRange) : 1;
                const Index globalRange = pNumGroups * rNumGroups * localRange;

                EIGEN_CONSTEXPR Index scratchSize = PannelParameters::LocalThreadSizeR * (PannelParameters::LocalThreadSizeP + PannelParameters::BC);
                auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(globalRange), cl::sycl::range<1>(localRange));
                if (rNumGroups > 1)
                {
                    CoeffReturnType* temp_pointer =
                        static_cast<CoeffReturnType*>(dev.allocate_temp(num_coeffs_to_preserve * rNumGroups * sizeof(CoeffReturnType)));
                    EvaluatorPointerType temp_accessor = dev.get(temp_pointer);
                    dev.template unary_kernel_launcher<CoeffReturnType, SyclReducerKerneType>(
                        self, temp_accessor, thread_range, scratchSize, reducer, pNumGroups, rNumGroups, num_coeffs_to_preserve, num_coeffs_to_reduce);

                    typedef SecondStepPartialReduction<CoeffReturnType, Index, EvaluatorPointerType, EvaluatorPointerType, Op> SecondStepPartialReductionKernel;

                    dev.template unary_kernel_launcher<CoeffReturnType, SecondStepPartialReductionKernel>(
                        temp_accessor,
                        output,
                        cl::sycl::nd_range<1>(cl::sycl::range<1>(pNumGroups * localRange), cl::sycl::range<1>(localRange)),
                        Index(1),
                        reducer,
                        num_coeffs_to_preserve,
                        rNumGroups);

                    self.device().deallocate_temp(temp_pointer);
                }
                else
                {
                    dev.template unary_kernel_launcher<CoeffReturnType, SyclReducerKerneType>(
                        self, output, thread_range, scratchSize, reducer, pNumGroups, rNumGroups, num_coeffs_to_preserve, num_coeffs_to_reduce);
                }
                return false;
            }
        };
    }  // namespace internal
}  // namespace TensorSycl

namespace internal {

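    // SYCL full reduction to a single value: FullReductionKernelFunctor produces one partial
    // result per work-group; when more than one work-group is used, SecondStepFullReducer
    // combines the partial results into the final value.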
    template <typename Self, typename Op, bool Vectorizable> struct FullReducer<Self, Op, Eigen::SyclDevice, Vectorizable>
    {
        typedef typename Self::CoeffReturnType CoeffReturnType;
        typedef typename Self::EvaluatorPointerType EvaluatorPointerType;
        static EIGEN_CONSTEXPR bool HasOptimizedImplementation = true;
        static EIGEN_CONSTEXPR int PacketSize = Self::PacketAccess ? Self::PacketSize : 1;
        static void run(const Self& self, Op& reducer, const Eigen::SyclDevice& dev, EvaluatorPointerType data)
        {
            typedef typename conditional<Self::PacketAccess, typename Self::PacketReturnType, CoeffReturnType>::type OutType;
            static_assert(!((EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1) & (EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1 - 1)),
                          "The Local thread size must be a power of 2 for the reduction "
                          "operation");
            EIGEN_CONSTEXPR Index local_range = EIGEN_SYCL_LOCAL_THREAD_DIM0 * EIGEN_SYCL_LOCAL_THREAD_DIM1;

            typename Self::Index inputSize = self.impl().dimensions().TotalSize();
            // Here we force the reduction to be at most a 2-step reduction:
            // our empirical results show that if each thread reduces at least 512
            // elements individually, we get better performance.
            const Index reductionPerThread = 2048;
            Index reductionGroup = dev.getPowerOfTwo((inputSize + (reductionPerThread * local_range - 1)) / (reductionPerThread * local_range), true);
            const Index num_work_group = std::min(reductionGroup, local_range);
            const Index global_range = num_work_group * local_range;

            auto thread_range = cl::sycl::nd_range<1>(cl::sycl::range<1>(global_range), cl::sycl::range<1>(local_range));
            typedef TensorSycl::internal::FullReductionKernelFunctor<Self, Op, local_range> reduction_kernel_t;
            if (num_work_group > 1)
            {
                CoeffReturnType* temp_pointer = static_cast<CoeffReturnType*>(dev.allocate_temp(num_work_group * sizeof(CoeffReturnType)));
                typename Self::EvaluatorPointerType tmp_global_accessor = dev.get(temp_pointer);
                dev.template unary_kernel_launcher<OutType, reduction_kernel_t>(self, tmp_global_accessor, thread_range, local_range, inputSize, reducer);

                typedef TensorSycl::internal::SecondStepFullReducer<CoeffReturnType, Op, EvaluatorPointerType, EvaluatorPointerType, Index, local_range>
                    GenericRKernel;
                dev.template unary_kernel_launcher<CoeffReturnType, GenericRKernel>(
                    tmp_global_accessor,
                    data,
                    cl::sycl::nd_range<1>(cl::sycl::range<1>(num_work_group), cl::sycl::range<1>(num_work_group)),
                    num_work_group,
                    reducer);

                dev.deallocate_temp(temp_pointer);
            }
            else
            {
                dev.template unary_kernel_launcher<OutType, reduction_kernel_t>(self, data, thread_range, local_range, inputSize, reducer);
            }
        }
    };
    // Vectorizable inner_most dim preserver (col reduction).
    template <typename Self, typename Op> struct OuterReducer<Self, Op, Eigen::SyclDevice>
    {
        static EIGEN_CONSTEXPR bool HasOptimizedImplementation = true;

        static bool run(const Self& self,
                        const Op& reducer,
                        const Eigen::SyclDevice& dev,
                        typename Self::EvaluatorPointerType output,
                        typename Self::Index num_coeffs_to_reduce,
                        typename Self::Index num_coeffs_to_preserve)
        {
            return ::Eigen::TensorSycl::internal::PartialReducerLauncher<Self, Op, ::Eigen::TensorSycl::internal::reduction_dim::outer_most>::run(
                self, reducer, dev, output, num_coeffs_to_reduce, num_coeffs_to_preserve);
        }
    };
    // row reduction
    template <typename Self, typename Op> struct InnerReducer<Self, Op, Eigen::SyclDevice>
    {
        static EIGEN_CONSTEXPR bool HasOptimizedImplementation = true;

        static bool run(const Self& self,
                        const Op& reducer,
                        const Eigen::SyclDevice& dev,
                        typename Self::EvaluatorPointerType output,
                        typename Self::Index num_coeffs_to_reduce,
                        typename Self::Index num_coeffs_to_preserve)
        {
            return ::Eigen::TensorSycl::internal::PartialReducerLauncher<Self, Op, ::Eigen::TensorSycl::internal::reduction_dim::inner_most>::run(
                self, reducer, dev, output, num_coeffs_to_reduce, num_coeffs_to_preserve);
        }
    };

    // ArgMax uses this kernel for partial reduction.
    // TODO(@mehdi.goli) come up with a better kernel
    // generic partial reduction
    template <typename Self, typename Op> struct GenericReducer<Self, Op, Eigen::SyclDevice>
    {
        static EIGEN_CONSTEXPR bool HasOptimizedImplementation = false;
        static bool run(const Self& self,
                        const Op& reducer,
                        const Eigen::SyclDevice& dev,
                        typename Self::EvaluatorPointerType output,
                        typename Self::Index num_values_to_reduce,
                        typename Self::Index num_coeffs_to_preserve)
        {
            typename Self::Index range, GRange, tileSize;
            dev.parallel_for_setup(num_coeffs_to_preserve, tileSize, range, GRange);

            dev.template unary_kernel_launcher<typename Self::CoeffReturnType, TensorSycl::internal::GenericNondeterministicReducer<Self, Op>>(
                self,
                output,
                cl::sycl::nd_range<1>(cl::sycl::range<1>(GRange), cl::sycl::range<1>(tileSize)),
                Index(1),
                reducer,
                range,
                (num_values_to_reduce != 0) ? num_values_to_reduce : static_cast<Index>(1));
            return false;
        }
    };

}  // namespace internal
}  // namespace Eigen

#endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSOR_REDUCTION_SYCL_HPP
